---
title: Time Alignment with micro-tcn
keywords: fastai
sidebar: home_sidebar
nb_path: "02_time_align.ipynb"
---
{% raw %}
{% endraw %}

Work in progress for NASH Hackathon, Dec 17, 2021

This is like the 01_td_demo notebook, only we use a different dataset and generalize the dataloader a bit.

Installs and imports

{% raw %}
# Next line only executes on Colab. Colab users: Please enable GPU in Edit > Notebook settings
! [ -e /content ] && pip install -Uqq pip fastai git+https://github.com/drscotthawley/fastproaudio.git

# Additional installs for this tutorial
%pip install -q fastai_minima torchsummary pyzenodo3 wandb

# Install micro-tcn and auraloss packages (from source, will take a little while)
%pip install -q wheel --ignore-requires-python git+https://github.com/csteinmetz1/micro-tcn.git  git+https://github.com/csteinmetz1/auraloss

# After this cell finishes, restart the kernel and continue below
WARNING: You are using pip version 21.3; however, version 21.3.1 is available.
You should consider upgrading via the '/home/shawley/envs/fastai/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.
  WARNING: Missing build requirements in pyproject.toml for git+https://github.com/csteinmetz1/auraloss.
  WARNING: The project does not specify a build backend, and pip cannot fall back to setuptools without 'wheel'.
{% endraw %} {% raw %}
from fastai.vision.all import *
from fastai.text.all import *
from fastai.callback.fp16 import *
import wandb
from fastai.callback.wandb import *
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from IPython.display import Audio 
import matplotlib.pyplot as plt
import torchsummary
from fastproaudio.core import *
from pathlib import Path
from glob import glob
import json
import re 
{% endraw %}

Dataset Generation Example

Jacob's making the real dataset, but here's a brief intuitive demo of the strategy: paste a bunch of audio samples at regular intervals ("the grid") in a long array, then make another version in which the paste locations are randomly perturbed. There's also a click track marking the grid.

We'll use Marco's guitar strike dataset (which is really from IDMT but whatever):

{% raw %}
path_audiomdpi = get_audio_data(URLs.MARCO)
horn = path_audiomdpi / "LeslieHorn"; horn.ls()
path_dry = horn /'dry'
audio_extensions = ['.m3u', '.ram', '.au', '.snd', '.mp3','.wav']
fnames_dry = get_files(path_dry, extensions=audio_extensions)
{% endraw %}

Let's just take a look at one guitar pluck

{% raw %}
waveform, sample_rate = torchaudio.load(fnames_dry[0])
show_audio(waveform, sample_rate) 
Shape: (1, 110250), Dtype: torch.float32, Duration: 2.5 s
Max:  1.000,  Min: -0.973, Mean: -0.000, Std Dev:  0.086
{% endraw %}

And now the main strategy: pasting this one sample (IRL we'll use lots of them) along a track.

{% raw %}
sample = waveform[0].numpy()  # just simplify array dimensions for this demo
sample = sample[int(0.63*sample_rate):]  # chop off the silence at the front for this demo

track_length = sample_rate*5
sample_len = sample.shape[-1]
target = np.zeros(track_length)
input = np.zeros(track_length)
click = np.zeros(track_length)

grid_interval = sample_rate

n_intervals = track_length // grid_interval
for i in range(n_intervals):                 # paste samples at regular intervals
    start = grid_interval*i 
    click[start] = 1                          # click track
    end = min( start+sample_len, track_length)
    target[start:end] = sample[0:end-start]  # paste the sample on the grid
    
    # perturb the paste location by some random amount
    rand_start = max(0, start + np.random.randint(-grid_interval//2,grid_interval//2))
    rand_end = min( rand_start+sample_len, track_length )
    input[rand_start:rand_end] = sample[0:rand_end-rand_start]
{% endraw %}

The click track will be treated as part of the multichannel input:

{% raw %}
fig = plt.figure(figsize=(14, 2))
plt.plot(click)
{% endraw %}

Input is randomly perturbed from the grid:

{% raw %}
fig = plt.figure(figsize=(14, 2))
plt.plot(input)
{% endraw %}

Target is on the grid:

{% raw %}
fig = plt.figure(figsize=(14, 2))
plt.plot(target)  # target is on the grid
{% endraw %}

The job of the network is: given input and click track, produce the target.

.... this was all in mono and with only one audio sample, but IRL we'll have multiple channels of audio and a variety of audio samples.
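The channel-stacking idea can be sketched as follows. This is just an illustration of the shapes involved; the names net_input and net_target are invented here, not from the dataset code:

```python
import numpy as np

sample_rate = 44100
track_length = sample_rate * 5

input = np.zeros(track_length)      # randomly-perturbed pastes, as built above
click = np.zeros(track_length)
click[::sample_rate] = 1.0          # a click on every grid point
target = np.zeros(track_length)     # pastes on the grid, as built above

# stack input + click into one 2-channel array for the network
net_input = np.stack([input, click])        # shape (2, 220500)
net_target = target[np.newaxis, :]          # shape (1, 220500)
```

With real data, each extra mono stem would just become another row in the stack.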

Now Using the Real Dataset

-- pause here ---

Jacob's still working on generating the dataset(s). He'll probably put it in a private Dropbox.

{% raw %}
path = Path('wherever jacob puts the data')


fnames_in = sorted(glob(str(path)+'/*/input*'))
fnames_targ = sorted(glob(str(path)+'/*/*targ*'))
ind = -1   # pick one spot in the list of files
fnames_in[ind], fnames_targ[ind]
('/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/input_260_.wav',
 '/home/shawley/.fastai/data/SignalTrain_LA2A_Reduced/Val/target_260_LA2A_2c__1__85.wav')
{% endraw %}

Input audio

{% raw %}
waveform, sample_rate = torchaudio.load(fnames_in[ind])
show_audio(waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.225,  Min: -0.218, Mean:  0.000, Std Dev:  0.038
{% endraw %}

Target output audio

{% raw %}
target, sr_targ = torchaudio.load(fnames_targ[ind])
show_audio(target, sr_targ)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.091,  Min: -0.103, Mean: -0.000, Std Dev:  0.021
{% endraw %}

Let's look at the difference.

Difference

{% raw %}
show_audio(target - waveform, sample_rate)
Shape: (1, 441000), Dtype: torch.float32, Duration: 10.0 s
Max:  0.144,  Min: -0.159, Mean: -0.000, Std Dev:  0.018
{% endraw %} {% raw %}
def get_accompanying_tracks(fn, fn_list, remove=False):
    """"Given one filename, and a list of all filenames, return a list of that filename and 
    any files it 'goes with'
    remove: remove these accompanying files from the main list.
    """
    # make a copies of fn & fn_list with all hyphen+stuff removed. 
    basename = re.sub(r'-[a-zA-Z0-9]+','', fn) 
    basename_list = [re.sub(r'-[a-zA-Z0-9]+','', x) for x in fn_list]
    
    # get indices of all elements of basename_list matching basename, return original filenames
    accompanying = [fn_list[i] for i, x in enumerate(basename_list) if x == basename]
    if remove: 
        for x in accompanying: 
            if x != fn: fn_list.remove(x)  # don't remove the file we search on though
    return accompanying # note accompanying list includes original file too
{% endraw %} {% raw %}
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
print(fn_list)
track = fn_list[1]
print("getting matching tracks for ",track)
tracks  = get_accompanying_tracks(fn_list[1], fn_list, remove=True)
print("Accompanying tracks are: ",tracks)
print("new list = ",fn_list) # should have the extra 21- tracks removed.
['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
getting matching tracks for  input_21-1_.wav
Accompanying tracks are:  ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav']
new list =  ['input_21-1_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
{% endraw %} {% raw %}
fn_list = ['input_21-0_.wav', 'input_21-1_.wav', 'input_21-hey_.wav', 'input_22_.wav', 'input_23_.wav', 'input_23-toms_.wav', 'input_24-0_.wav', 'input_24-kick_.wav']
fn_list_save = fn_list.copy() 
for x in fn_list:
    get_accompanying_tracks(x, fn_list, remove=True)
fn_list, fn_list_save
(['input_21-0_.wav', 'input_22_.wav', 'input_23_.wav', 'input_24-0_.wav'],
 ['input_21-0_.wav',
  'input_21-1_.wav',
  'input_21-hey_.wav',
  'input_22_.wav',
  'input_23_.wav',
  'input_23-toms_.wav',
  'input_24-0_.wav',
  'input_24-kick_.wav'])
{% endraw %}

Dataset class and Dataloaders

Here we modify Christian's SignalTrainLA2ADataset class.

We subclass the original dataset class that Christian made so that __getitem__ "packs" the params and inputs together. This will be loading multichannel wav files.

{% raw %}
from microtcn.data import SignalTrainLA2ADataset

class SignalTrainLA2ADataset_fastai(SignalTrainLA2ADataset):
    "For fastai's sake, have getitem pack the inputs and params together"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def __getitem__(self, idx):
        input, target, params = super().__getitem__(idx)
        return torch.cat((input,params),dim=-1), target   # pack input and params together
{% endraw %}
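A quick sanity check of the pack/unpack round trip. This is a sketch with a made-up short clip length; the real batches here are (8, 1, 65536) audio with 2 params:

```python
import torch

n_params = 2
input = torch.randn(8, 1, 100)          # stand-in for an audio batch
params = torch.rand(8, 1, n_params)     # stand-in for the knob settings

# what the _fastai dataset's __getitem__ returns
packed = torch.cat((input, params), dim=-1)   # shape (8, 1, 102)

# unpacking, as TCNModel_fastai.forward will do later
x, p = packed[:, :, :-n_params], packed[:, :, -n_params:]
assert torch.equal(x, input) and torch.equal(p, params)
```

Packing like this lets fastai treat (input, params) as a single tensor, which is what its Learner machinery expects.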

Dataset for loading multiple mono files and packing them together as multichannel:

{% raw %}
'''
class MonoToMCDataset(torch.utils.data.Dataset):
    """
    UPDATE: turns out we're going to stick to Christian's original dataloader class and just use 
    conversion scripts to pack or unpack mono WAV files into multichannel WAV files.
    ----
    
    Modifying Steinmetz' micro-tcn code so we can load the kind of multichannel audio we want.
    The difference is that now, we group files that are similar except for a hyphen-designation, 
    e.g. input_235-1_.wav, input_235-2_.wav get read into one tensor.
    
    The 'trick' will be that we only ever store one filename 'version' of a group of files, but whenever we 
    want to try to load that file, we will also grab all its associated files. 
    
    Like SignalTrain LA2A dataset only more general"""
    def __init__(self, root_dir, subset="train", length=16384, preload=False, half=True, fraction=1.0, use_soundfile=False):
        """
        Args:
            root_dir (str): Path to the root directory of the SignalTrain dataset.
            subset (str, optional): Pull data either from "train", "val", "test", or "full" subsets. (Default: "train")
            length (int, optional): Number of samples in the returned examples. (Default: 40)
            preload (bool, optional): Read in all data into RAM during init. (Default: False)
            half (bool, optional): Store the float32 audio as float16. (Default: True)
            fraction (float, optional): Fraction of the data to load from the subset. (Default: 1.0)
            use_soundfile (bool, optional): Use the soundfile library to load instead of torchaudio. (Default: False)
        """
        self.root_dir = root_dir
        self.subset = subset
        self.length = length
        self.preload = preload
        self.half = half
        self.fraction = fraction
        self.use_soundfile = use_soundfile

        if self.subset == "full":
            self.target_files = glob.glob(os.path.join(self.root_dir, "**", "target_*.wav"))
            self.input_files  = glob.glob(os.path.join(self.root_dir, "**", "input_*.wav"))
        else:
            # get all the target files in the directory first
            self.target_files = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "target_*.wav"))
            self.input_files  = glob.glob(os.path.join(self.root_dir, self.subset.capitalize(), "input_*.wav"))

        self.examples = [] 
        self.minutes = 0  # total number of minutes in the subset

        # ensure that the sets are ordered correctly
        self.target_files.sort()
        self.input_files.sort()

        # get the parameters 
        self.params = [(float(f.split("__")[1].replace(".wav","")), float(f.split("__")[2].replace(".wav",""))) for f in self.target_files]

        
        # SHH: HERE is where we'll package similar hyphen-designated files together. list comprehension here wouldn't be good btw.
        # essentially we are removing 'duplicates'. the first file of each group will be the signifier of all of them
        self.target_files_all, self.input_files_all = self.target_files.copy(), self.input_files.copy() # save a copy of original list
        for x in self.target_files:  # remove extra accompanying tracks from main list that loader will use
            get_accompanying_tracks(x, self.target_files, remove=True)
        for x in self.input_files:
            get_accompanying_tracks(x, self.input_files, remove=True)
        # make a dict that will map main file name to list of accompanying files (including itself)
        self.target_accomp = {f: get_accompanying_tracks(f, self.target_files_all) for f in self.target_files}
        self.input_accomp = {f: get_accompanying_tracks(f, self.input_files_all) for f in self.input_files}
        
        # loop over files to count total length
        for idx, (tfile, ifile, params) in enumerate(zip(self.target_files, self.input_files, self.params)):

            ifile_id = int(os.path.basename(ifile).split("_")[1])
            tfile_id = int(os.path.basename(tfile).split("_")[1])
            if ifile_id != tfile_id:
                raise RuntimeError(f"Found non-matching file ids: {ifile_id} != {tfile_id}! Check dataset.")

            md = torchaudio.info(tfile)
            num_frames = md.num_frames

            if self.preload:
                sys.stdout.write(f"* Pre-loading... {idx+1:3d}/{len(self.target_files):3d} ...\r")
                sys.stdout.flush()
                
                input, sr  = self.load_accompanying(ifile, self.input_accomp)
                target, sr = self.load_accompanying(tfile, self.target_accomp)

                num_frames = int(np.min([input.shape[-1], target.shape[-1]]))
                if input.shape[-1] != target.shape[-1]:
                    print(os.path.basename(ifile), input.shape[-1], os.path.basename(tfile), target.shape[-1])
                    raise RuntimeError("Found potentially corrupt file!")
                if self.half:
                    input = input.half()
                    target = target.half()
            else:
                input = None
                target = None

            # create one entry for each patch
            self.file_examples = []
            for n in range((num_frames // self.length)):
                offset = int(n * self.length)
                end = offset + self.length
                self.file_examples.append({"idx": idx, 
                                           "target_file" : tfile,
                                           "input_file" : ifile,
                                           "input_audio" : input[:,offset:end] if input is not None else None,
                                           "target_audio" : target[:,offset:end] if input is not None else None,
                                           "params" : params,
                                           "offset": offset,
                                           "frames" : num_frames})

            # add to overall file examples
            self.examples += self.file_examples
        
        # use only a fraction of the subset data if applicable
        if self.subset == "train":
            classes = set([ex['params'] for ex in self.examples])
            n_classes = len(classes) # number of unique compressor configurations
            fraction_examples = int(len(self.examples) * self.fraction)
            n_examples_per_class = int(fraction_examples / n_classes)
            n_min_total = ((self.length * n_examples_per_class * n_classes) / md.sample_rate) / 60 
            n_min_per_class = ((self.length * n_examples_per_class) / md.sample_rate) / 60 
            print(sorted(classes))
            print(f"Total Examples: {len(self.examples)}     Total classes: {n_classes}")
            print(f"Fraction examples: {fraction_examples}    Examples/class: {n_examples_per_class}")
            print(f"Training with {n_min_per_class:0.2f} min per class    Total of {n_min_total:0.2f} min")

            if n_examples_per_class <= 0: 
                raise ValueError(f"Fraction `{self.fraction}` set too low. No examples selected.")

            sampled_examples = []

            for config_class in classes: # select N examples from each class
                class_examples = [ex for ex in self.examples if ex["params"] == config_class]
                example_indices = np.random.randint(0, high=len(class_examples), size=n_examples_per_class)
                class_examples = [class_examples[idx] for idx in example_indices]
                extra_factor = int(1/self.fraction)
                sampled_examples += class_examples * extra_factor

            self.examples = sampled_examples

        self.minutes = ((self.length * len(self.examples)) / md.sample_rate) / 60 

        # we then want to get the input files
        print(f"Located {len(self.examples)} examples totaling {self.minutes:0.2f} min in the {self.subset} subset.")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        if self.preload:
            audio_idx = self.examples[idx]["idx"]
            offset = self.examples[idx]["offset"]
            input = self.examples[idx]["input_audio"]
            target = self.examples[idx]["target_audio"]
        else:
            offset = self.examples[idx]["offset"] 
            input_name = self.examples[idx]["input_file"]
            target_name = self.examples[idx]["target_file"]
            input = torch.empty((len(self.input_accomp[input_name]), self.length))
            for c, fname in enumerate(self.input_accomp[input_name]):
                input[c], sr  = torchaudio.load(fname, 
                                            num_frames=self.length, 
                                            frame_offset=offset, 
                                            normalize=False)
            target = torch.empty((len(self.target_accomp[target_name]), self.length))
            for c, fname in enumerate(self.target_accomp[target_name]):
                target[c], sr = torchaudio.load(fname, 
                                        num_frames=self.length, 
                                        frame_offset=offset, 
                                        normalize=False)
            if self.half:
                input = input.half()
                target = target.half()

        # at random with p=0.5 flip the phase 
        if np.random.rand() > 0.5:
            input *= -1
            target *= -1

        # then get the tuple of parameters
        params = torch.tensor(self.examples[idx]["params"]).unsqueeze(0)
        params[:,1] /= 100

        return input, target, params

    def load(self, filename):
        if self.use_soundfile:
            x, sr = sf.read(filename, always_2d=True)
            x = torch.tensor(x.T)
        else:
            x, sr = torchaudio.load(filename, normalize=False)
        return x, sr
    
    def load_accompanying(self, filename, accomp_dict):
        accomp = accomp_dict[filename]
        self.num_channels = len(accomp)
        md = torchaudio.info(filename)   # TODO:fix: assumes all accompanying tracks are the same shape, etc! 
        num_frames = md.num_frames
        data = torch.empty((self.num_channels,num_frames))
        for c, afile in enumerate(accomp):
            data[c], sr  = self.load(afile)
        return data, sr
'''
{% endraw %} {% raw %}
class Args(object):  # stand-in for parseargs. these are all micro-tcn defaults
    model_type ='tcn'
    root_dir = str(path)
    preload = False
    sample_rate = 44100
    shuffle = True
    train_subset = 'train'
    val_subset = 'val'
    train_length = 65536
    train_fraction = 1.0
    eval_length = 131072
    batch_size = 8   # original is 32, my laptop needs smaller, esp. w/o half precision
    num_workers = 4
    precision = 32  # LEAVE AS 32 FOR NOW: HALF PRECISION (16) NOT WORKING YET -SHH
    n_params = 2
    
args = Args()

#if args.precision == 16:  torch.set_default_dtype(torch.float16)

# setup the dataloaders
train_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    subset=args.train_subset, 
                    fraction=args.train_fraction,
                    half=True if args.precision == 16 else False, 
                    preload=args.preload, 
                    length=args.train_length)

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                    shuffle=args.shuffle,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

val_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset=args.val_subset,
                    length=args.eval_length)

val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)
[(0.0, 0.0), (0.0, 5.0), (0.0, 15.0), (0.0, 20.0), (0.0, 25.0), (0.0, 30.0), (0.0, 35.0), (0.0, 40.0), (0.0, 45.0), (0.0, 55.0), (0.0, 60.0), (0.0, 65.0), (0.0, 70.0), (0.0, 75.0), (0.0, 80.0), (0.0, 85.0), (0.0, 90.0), (0.0, 95.0), (0.0, 100.0), (1.0, 0.0), (1.0, 5.0), (1.0, 15.0), (1.0, 20.0), (1.0, 25.0), (1.0, 30.0), (1.0, 35.0), (1.0, 40.0), (1.0, 45.0), (1.0, 50.0), (1.0, 55.0), (1.0, 60.0), (1.0, 65.0), (1.0, 75.0), (1.0, 80.0), (1.0, 85.0), (1.0, 90.0), (1.0, 95.0), (1.0, 100.0)]
Total Examples: 396     Total classes: 38
Fraction examples: 396    Examples/class: 10
Training with 0.25 min per class    Total of 9.41 min
Located 380 examples totaling 9.41 min in the train subset.
Located 45 examples totaling 2.23 min in the val subset.
{% endraw %}

If the user requested fp16 precision then we need to install NVIDIA apex:

{% raw %}
if False and args.precision == 16:
    %pip install -q --disable-pip-version-check --no-cache-dir git+https://github.com/NVIDIA/apex
    from apex.fp16_utils import convert_network
{% endraw %}

Define the model(s)

Christian defined a lot of models. We'll do the TCN-300 and the LSTM.

{% raw %}
from microtcn.tcn_bare import TCNModel as TCNModel
#from microtcn.lstm import LSTMModel # actually the LSTM depends on a lot of Lightning stuff, so we'll skip that
from microtcn.utils import center_crop, causal_crop

class TCNModel_fastai(TCNModel):
    "For fastai's sake, unpack the inputs and params"
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def forward(self, x, p=None):
        if (p is None) and (self.nparams > 0):  # unpack the params if needed
            assert len(list(x.size())) == 3   # sanity check 
            x, p = x[:,:,0:-self.nparams], x[:,:,-self.nparams:]
        return super().forward(x, p=p)
{% endraw %} {% raw %}
# micro-tcn defines several different model configurations. I just chose one of them. 
train_configs = [
      {"name" : "TCN-300",
     "model_type" : "tcn",
     "nblocks" : 10,
     "dilation_growth" : 2,
     "kernel_size" : 15,
     "causal" : False,
     "train_fraction" : 1.00,
     "batch_size" : args.batch_size
    }
]

dict_args = train_configs[0]
dict_args["nparams"] = 2

model = TCNModel_fastai(**dict_args)
dtype = torch.float32
{% endraw %}
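As a back-of-the-envelope check on this configuration, we can estimate the TCN's receptive field by summing the context each dilated block adds. This sketch assumes the dilation of block n is dilation_growth**n (treat the exact numbers as an estimate, not a readout of micro-tcn internals):

```python
def tcn_receptive_field(nblocks=10, kernel_size=15, dilation_growth=2):
    # each block's conv adds (kernel_size - 1) * dilation samples of context
    return 1 + sum((kernel_size - 1) * dilation_growth**n for n in range(nblocks))

rf = tcn_receptive_field()
print(rf)             # 14323 samples, roughly 0.32 s at 44.1 kHz
print(65536 - rf + 1) # 51214: the un-padded output length for a 65536-sample input
```

That second number lines up with the cropped prediction lengths that show up when we run a batch through the model below.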

Let's take a look at the model:

{% raw %}
# this summary allows one to compare the original TCNModel with the TCNModel_fastai
if type(model) == TCNModel_fastai:
    torchsummary.summary(model, [(1,args.train_length)], device="cpu")
else:
    torchsummary.summary(model, [(1,args.train_length),(1,2)], device="cpu")
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 16]              48
              ReLU-2                [-1, 1, 16]               0
            Linear-3                [-1, 1, 32]             544
              ReLU-4                [-1, 1, 32]               0
            Linear-5                [-1, 1, 32]           1,056
              ReLU-6                [-1, 1, 32]               0
            Conv1d-7            [-1, 32, 65520]             480
            Linear-8                [-1, 1, 64]           2,112
       BatchNorm1d-9            [-1, 32, 65520]               0
             FiLM-10            [-1, 32, 65520]               0
            PReLU-11            [-1, 32, 65520]              32
           Conv1d-12            [-1, 32, 65534]              32
         TCNBlock-13            [-1, 32, 65520]               0
           Conv1d-14            [-1, 32, 65492]          15,360
           Linear-15                [-1, 1, 64]           2,112
      BatchNorm1d-16            [-1, 32, 65492]               0
             FiLM-17            [-1, 32, 65492]               0
            PReLU-18            [-1, 32, 65492]              32
           Conv1d-19            [-1, 32, 65520]              32
         TCNBlock-20            [-1, 32, 65492]               0
           Conv1d-21            [-1, 32, 65436]          15,360
           Linear-22                [-1, 1, 64]           2,112
      BatchNorm1d-23            [-1, 32, 65436]               0
             FiLM-24            [-1, 32, 65436]               0
            PReLU-25            [-1, 32, 65436]              32
           Conv1d-26            [-1, 32, 65492]              32
         TCNBlock-27            [-1, 32, 65436]               0
           Conv1d-28            [-1, 32, 65324]          15,360
           Linear-29                [-1, 1, 64]           2,112
      BatchNorm1d-30            [-1, 32, 65324]               0
             FiLM-31            [-1, 32, 65324]               0
            PReLU-32            [-1, 32, 65324]              32
           Conv1d-33            [-1, 32, 65436]              32
         TCNBlock-34            [-1, 32, 65324]               0
           Conv1d-35            [-1, 32, 65100]          15,360
           Linear-36                [-1, 1, 64]           2,112
      BatchNorm1d-37            [-1, 32, 65100]               0
             FiLM-38            [-1, 32, 65100]               0
            PReLU-39            [-1, 32, 65100]              32
           Conv1d-40            [-1, 32, 65324]              32
         TCNBlock-41            [-1, 32, 65100]               0
           Conv1d-42            [-1, 32, 64652]          15,360
           Linear-43                [-1, 1, 64]           2,112
      BatchNorm1d-44            [-1, 32, 64652]               0
             FiLM-45            [-1, 32, 64652]               0
            PReLU-46            [-1, 32, 64652]              32
           Conv1d-47            [-1, 32, 65100]              32
         TCNBlock-48            [-1, 32, 64652]               0
           Conv1d-49            [-1, 32, 63756]          15,360
           Linear-50                [-1, 1, 64]           2,112
      BatchNorm1d-51            [-1, 32, 63756]               0
             FiLM-52            [-1, 32, 63756]               0
            PReLU-53            [-1, 32, 63756]              32
           Conv1d-54            [-1, 32, 64652]              32
         TCNBlock-55            [-1, 32, 63756]               0
           Conv1d-56            [-1, 32, 61964]          15,360
           Linear-57                [-1, 1, 64]           2,112
      BatchNorm1d-58            [-1, 32, 61964]               0
             FiLM-59            [-1, 32, 61964]               0
            PReLU-60            [-1, 32, 61964]              32
           Conv1d-61            [-1, 32, 63756]              32
         TCNBlock-62            [-1, 32, 61964]               0
           Conv1d-63            [-1, 32, 58380]          15,360
           Linear-64                [-1, 1, 64]           2,112
      BatchNorm1d-65            [-1, 32, 58380]               0
             FiLM-66            [-1, 32, 58380]               0
            PReLU-67            [-1, 32, 58380]              32
           Conv1d-68            [-1, 32, 61964]              32
         TCNBlock-69            [-1, 32, 58380]               0
           Conv1d-70            [-1, 32, 51212]          15,360
           Linear-71                [-1, 1, 64]           2,112
      BatchNorm1d-72            [-1, 32, 51212]               0
             FiLM-73            [-1, 32, 51212]               0
            PReLU-74            [-1, 32, 51212]              32
           Conv1d-75            [-1, 32, 58380]              32
         TCNBlock-76            [-1, 32, 51212]               0
           Conv1d-77             [-1, 1, 51212]              33
================================================================
Total params: 162,161
Trainable params: 162,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 922.11
Params size (MB): 0.62
Estimated Total Size (MB): 922.98
----------------------------------------------------------------
{% endraw %}

Getting the model into fastai form

Zach Mueller made a very helpful fastai_minima package that we'll use, following his instructions.

TODO: Zach says I should either use fastai or fastai_minima, not mix them like I'm about to do. But what I have below is the only thing that works right now. ;-)

{% raw %}
# I guess we could've imported these up at the top of the notebook...
from torch import optim
from fastai_minima.optimizer import OptimWrapper
#from fastai_minima.learner import Learner  # this doesn't include lr_find()
from fastai.learner import Learner
from fastai_minima.learner import DataLoaders
#from fastai_minima.callback.training_utils import CudaCallback, ProgressCallback # not sure if I need these
{% endraw %} {% raw %}
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))

dls = DataLoaders(train_dataloader, val_dataloader)
{% endraw %}

Checking: Let's make sure the Dataloaders are working

{% raw %}
if args.precision==16: 
    dtype = torch.float16
    model = convert_network(model, torch.float16)

model = model.to('cuda:0')
if type(model) == TCNModel_fastai:
    print("We're using Hawley's modified code")
    packed, targ = dls.one_batch()
    inp, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
else:
    print("We're using Christian's version of Dataloader and model")
    inp, targ, params = dls.one_batch()
    pred = model.forward(inp.to('cuda:0',dtype=dtype), p=params.to('cuda:0', dtype=dtype))
print(f"input  = {inp.size()}\ntarget = {targ.size()}\nparams = {params.size()}\npred   = {pred.size()}")
We're using Hawley's modified code
input  = torch.Size([8, 1, 65536])
target = torch.Size([8, 1, 65536])
params = torch.Size([8, 1, 2])
pred   = torch.Size([8, 1, 51214])
{% endraw %}

We can make the pred and target the same length by cropping when we compute the loss:

{% raw %}
class Crop_Loss:
    "Crop target size to match preds"
    def __init__(self, axis=-1, causal=False, reduction="mean", func=nn.L1Loss):
        store_attr()
        self.loss_func = func()
    def __call__(self, pred, targ):
        targ = causal_crop(targ, pred.shape[-1]) if self.causal else center_crop(targ, pred.shape[-1])
        #pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, f'pred.shape = {pred.shape} but targ.shape = {targ.shape}'
        return self.loss_func(pred,targ).flatten().mean() if self.reduction == "mean" else self.loss_func(pred,targ).flatten().sum()
    

# we could add a metric like MSE if we want
def crop_mse(pred, targ, causal=False): 
    targ = causal_crop(targ, pred.shape[-1]) if causal else center_crop(targ, pred.shape[-1])
    return ((pred - targ)**2).mean()
{% endraw %}
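For reference, the crop helpers used above come from the micro-tcn package; they just trim the last axis. Sketches of equivalent functions, assuming the same semantics (center keeps the middle, causal keeps the most recent samples):

```python
import torch

def center_crop(x, length: int):
    # keep `length` samples from the middle of the last axis
    start = (x.shape[-1] - length) // 2
    return x[..., start:start + length]

def causal_crop(x, length: int):
    # keep the most recent `length` samples (end of the last axis)
    return x[..., -length:]

targ = torch.arange(10.).reshape(1, 1, 10)
print(center_crop(targ, 4))  # middle 4 samples
print(causal_crop(targ, 4))  # last 4 samples
```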

Enable logging with WandB:

{% raw %}
wandb.login()
wandb: Currently logged in as: drscotthawley (use `wandb login --relogin` to force relogin)
True
{% endraw %}

Define the fastai Learner and callbacks

We're going to add a new custom WandBAudio callback further below, which we'll use when we call fit().

WandBAudio Callback

In order to log audio samples, let's write our own audio-logging callback for fastai:

{% raw %}
class WandBAudio(Callback):
    """Progress-like callback: log audio to WandB"""
    order = ProgressCallback.order+1
    def __init__(self, n_preds=5, sample_rate=44100):
        store_attr()

    def after_epoch(self):  
        if not self.learn.training:
            with torch.no_grad():
                preds, targs = [x.detach().cpu().numpy().copy() for x in [self.learn.pred, self.learn.y]]
            log_dict = {}
            for i in range(min(self.n_preds, preds.shape[0])):  # note wandb only supports mono
                log_dict[f"preds_{i}"] = wandb.Audio(preds[i,0,:], caption=f"preds_{i}", sample_rate=self.sample_rate)
            wandb.log(log_dict)
{% endraw %}

Learner and wandb init

{% raw %}
wandb.init(project='micro-tcn-fastai')#  no name, name=json.dumps(dict_args))

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func,
               cbs= [WandbCallback()])
{% endraw %}

Train the model

We can use the fastai learning rate finder to suggest a learning rate:

{% raw %}
learn.lr_find(end_lr=0.1) 
SuggestedLRs(valley=0.0006918309954926372)
{% endraw %}

And now we'll train using the one-cycle LR schedule, with the WandBAudio callback. (Ignore any warning messages)

{% raw %}
epochs = 20  # change to 50 for better results but a longer wait
learn.fit_one_cycle(epochs, lr_max=3e-3, cbs=WandBAudio(sample_rate=args.sample_rate))
Could not gather input dimensions
WandbCallback requires use of "SaveModelCallback" to log best model
WandbCallback was not able to prepare a DataLoader for logging prediction samples -> 
epoch train_loss valid_loss crop_mse time
0 0.143242 0.098410 0.020299 00:06
1 0.096335 0.061745 0.007963 00:05
2 0.065788 0.035349 0.003570 00:05
3 0.045120 0.027977 0.001921 00:05
4 0.034311 0.023991 0.001443 00:05
5 0.026962 0.020367 0.001035 00:06
6 0.023846 0.020088 0.000883 00:05
7 0.021708 0.015346 0.000704 00:06
8 0.019866 0.026435 0.001117 00:06
9 0.017529 0.012842 0.000533 00:05
10 0.016500 0.013006 0.000504 00:05
11 0.015390 0.011723 0.000425 00:06
12 0.014275 0.012459 0.000437 00:06
13 0.013890 0.012470 0.000408 00:05
14 0.013401 0.013570 0.000454 00:05
15 0.012933 0.011421 0.000390 00:06
16 0.012545 0.010564 0.000362 00:05
17 0.012153 0.011395 0.000392 00:05
18 0.011879 0.010478 0.000356 00:05
19 0.011740 0.010412 0.000361 00:06
{% endraw %} {% raw %}
wandb.finish() # call wandb.finish() after training or your logs may be incomplete

Waiting for W&B process to finish, PID 1852379... (success).

Run history:


crop_mse█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dampening_0▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr_0▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████
nesterov_0▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss█▆▄▃▃▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss█▇▅▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss█▅▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
wd_0▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:


crop_mse0.00036
dampening_00
epoch20
lr_00.0
mom_00.95
nesterov_0False
raw_loss0.01222
train_loss0.01174
valid_loss0.01041
wd_00
Synced 5 W&B file(s), 100 media file(s), 0 artifact file(s) and 0 other file(s)
Synced fresh-salad-56: https://wandb.ai/drscotthawley/micro-tcn-fastai/runs/9w1h46em
Find logs at: ./wandb/run-20211025_091818-9w1h46em/logs
{% endraw %} {% raw %}
learn.save('micro-tcn-fastai')
Path('models/micro-tcn-fastai.pth')
{% endraw %}

Go check out the resulting run logs, graphs, and audio samples at https://wandb.ai/drscotthawley/micro-tcn-fastai, or... lemme see if I can embed some results below:

...ok, it looks like the WandB results iframe (with cool graphs & audio) is getting filtered out of the docs (by nbdev and/or jekyll). But if you open this notebook file -- e.g. click the "Open in Colab" badge at the top -- and scroll down, you'll see the report. Or just go to the WandB link posted above!

TODO: Inference / Evaluation / Analysis

Load in the testing data

{% raw %}
test_dataset = SignalTrainLA2ADataset_fastai(args.root_dir, 
                    preload=args.preload,
                    half=True if args.precision == 16 else False,
                    subset='test',
                    length=args.eval_length)

test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                    shuffle=False,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                    pin_memory=True)

learn = Learner(dls, model, loss_func=Crop_Loss(), metrics=crop_mse, opt_func=opt_func, cbs=[])
learn.load('micro-tcn-fastai')
Located 9 examples totaling 0.45 min in the test subset.
<fastai.learner.Learner at 0x7fcff933b430>
{% endraw %}

^^ 9 examples? I thought there were only 3:

{% raw %}
!ls {path}/Test
input_235_.wav	input_259_.wav		       target_256_LA2A_2c__1__65.wav
input_256_.wav	target_235_LA2A_2c__0__65.wav  target_259_LA2A_2c__1__80.wav
{% endraw %}

...Ok I don't understand that yet. Moving on:

Let's get some predictions from the model. Note that the length of these predictions will be greater than in training, because we specified the lengths differently:

{% raw %}
print(args.train_length, args.eval_length)
65536 131072
{% endraw %}
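The output is shorter than the input by a fixed number of samples, set by the model's receptive field (presumably `receptive_field - 1`), not by the input length. So the eval-time prediction length can be anticipated from the shapes printed in the dataloader check above (a quick sanity check, assuming the same model):

```python
# samples lost to convolution is fixed by the receptive field, not the input length
train_in, train_out = 65536, 51214   # shapes printed in the dataloader check above
eval_in = 131072                     # args.eval_length
reduction = train_in - train_out
print(eval_in - reduction)           # -> 116750, the eval-time prediction length
```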

Handy routine to grab some data and run it through the model to get predictions:

{% raw %}
def get_pred_batch(dataloader, crop_target=True, causal=False):
    packed, target = next(iter(dataloader))
    input, params = packed[:,:,0:-dict_args['nparams']], packed[:,:,-dict_args['nparams']:]
    pred = model.forward(packed.to('cuda:0', dtype=dtype))
    if crop_target: target = causal_crop(target, pred.shape[-1]) if causal else center_crop(target, pred.shape[-1])
    input, params, target, pred = [x.detach().cpu() for x in [input, params, target, pred]]
    return input, params, target, pred
{% endraw %} {% raw %}
input, params, target, pred = get_pred_batch(test_dataloader, causal=dict_args['causal'])
i = 0  # just look at the first element
print(f"------- i = {i} ---------\n")
print(f"prediction:")
show_audio(pred[i], sample_rate)
------- i = 0 ---------

prediction:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.139,  Min: -0.147, Mean:  0.000, Std Dev:  0.037
{% endraw %} {% raw %}
print(f"target:")
show_audio(target[i], sample_rate)
target:
Shape: (1, 116750), Dtype: torch.float32, Duration: 2.647392290249433 s
Max:  0.215,  Min: -0.202, Mean:  0.000, Std Dev:  0.053
{% endraw %}
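Since the target has already been cropped to the prediction length by get_pred_batch, any elementwise metric applies directly. A quick sketch with stand-in tensors (the real `pred[i]`/`target[i]` come from the batch above):

```python
import torch

pred = torch.zeros(1, 1, 116750)         # stand-in for pred[i]
targ = 0.01 * torch.ones(1, 1, 116750)   # stand-in for target[i]
mse = ((pred - targ) ** 2).mean()        # shapes match, so this is plain MSE
print(f"{mse.item():.6f}")
```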

TODO: More. We're not finished. I'll come back and add more to this later.

Deployment / Plugins

Check out Christian's GitHub page for micro-tcn, where he provides instructions and JUCE files for rendering the model as an audio plugin. Pretty sure you can only do this with the causal models, which I didn't include -- yet!